#THIS FILE BUILDS (AND OPTIONALLY TRAINS AND EVALUATES) A STEERABLE CNP AND A CONV CNP ARCHITECTURE
#HYPERPARAMETERS:
#1. Dimension of covariance estimation: 3 - diagonal covariance estimation
#2. Number of layers: 5
#LIBRARIES:
#Tensors:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as utils
#E(2)-steerable CNNs - library:
from e2cnn import gspaces
from e2cnn import nn as G_CNN
#Plotting in 2d/3d:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import Ellipse
from matplotlib.colors import Normalize
import matplotlib.cm as cm
#Tools:
import datetime
import sys
#Own files:
import Kernel_and_GP_tools as GP
import My_Tools
import Steerable_CNP_Models as My_Models
#GLOBAL SETTINGS:
# Use single precision as the default tensor dtype (torch.float is float32).
# NOTE(review): a previous comment here said "double", but the code has always
# selected float32; switch to torch.double only if double precision is wanted.
torch.set_default_dtype(torch.float)
# Scale used when plotting vector fields with plt.quiver.
quiver_scale=15
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")
def SETUP_EXP_4_Cyclic_GP_div_free(Training_par,N=8,batch_size=3):
    """Build a ConvCNP operator and a steerable-CNP operator for Exp. 4.

    Both models use the same grid, embedding-kernel and output-kernel settings
    and the same GP data loaders. They differ only in the decoder: a plain
    ``nn.Conv2d`` stack versus an E(2)-steerable decoder over the group given
    by ``gspaces.Rot2dOnR2(N=N)``.

    Args:
        Training_par: keyword arguments forwarded to
            ``My_Models.Steerable_CNP_Operator`` (e.g. ``Max_n_context_points``,
            ``n_epochs``, ``learning_rate``).
        N: group order passed to ``gspaces.Rot2dOnR2``.
        batch_size: batch size for the GP train/test data loaders.

    Returns:
        Tuple ``(Conv_CNP_Operator, Geom_CNP_Operator, GP_parameters)``.
        ``GP_parameters`` presumably describes the GP behind dataset "37845"
        (a "div_free" kernel) -- TODO confirm against Kernel_and_GP_tools.
    """
    G_act = gspaces.Rot2dOnR2(N=N)
    # Input field type: one copy of irrep(1).
    feat_type_in = G_CNN.FieldType(G_act, [G_act.irrep(1)])
    GP_train_data_loader, GP_test_data_loader = GP.load_2d_GP_data(Id="37845", batch_size=batch_size)
    GP_parameters = {'l_scale': 1, 'sigma_var': 1, 'kernel_type': "div_free", 'obs_noise': 1e-4, 'B': None, 'Ker_project': False}
    Operator_par = {'train_data_loader': GP_train_data_loader, 'test_data_loader': GP_test_data_loader}

    # Define the grid, the embedding-kernel parameters and the encoders.
    grid_dict = {'x_range': [-3, 3], 'y_range': [-3, 3], 'n_x_axis': 20, 'n_y_axis': 20}
    kernel_dict_emb = {'sigma_var': 1, 'kernel_type': "rbf", 'Ker_project': False}
    # Each CNP gets its OWN encoder instance. Previously one module was shared by
    # both models, so any encoder parameters/buffers were tied: training one model
    # changed the other, and loading one model's state dict overwrote the other's
    # encoder. Separate instances are harmless if the encoder is stateless.
    encoder_conv = My_Models.Steerable_Encoder(**grid_dict, kernel_dict=dict(kernel_dict_emb), normalize=True)
    encoder_geom = My_Models.Steerable_Encoder(**grid_dict, kernel_dict=dict(kernel_dict_emb), normalize=True)

    # Kernel parameters for the output kernel smoother.
    kernel_dict_out = {'sigma_var': 1, 'kernel_type': "rbf", 'B': None, 'Ker_project': False}

    #---------------------Conv CNP decoder-----------------------------
    # 3 input / 5 output channels mirror the steerable decoder's first field type
    # (trivial + irrep(1)) and its output type (irrep(1) + psd_rep); the latter
    # presumes psd_rep is 3-dimensional (dim_cov_est=3) -- confirm in My_Tools.
    conv_decoder = nn.Sequential(nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2),
                                 nn.ReLU(),
                                 nn.Conv2d(16, 16, kernel_size=7, stride=1, padding=3),
                                 nn.ReLU(),
                                 nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=2),
                                 nn.ReLU(),
                                 nn.Conv2d(16, 12, kernel_size=7, stride=1, padding=3),
                                 nn.ReLU(),
                                 nn.Conv2d(12, 5, kernel_size=5, stride=1, padding=2))

    #---------------------Steerable CNP decoder-----------------------------
    # Define the feature types (input, hidden layers, output):
    psd_rep, _ = My_Tools.get_pre_psd_rep(G_act)
    feat_type_out = G_CNN.FieldType(G_act, [G_act.irrep(1), psd_rep])
    feat_types = [G_CNN.FieldType(G_act, [G_act.trivial_repr, G_act.irrep(1)]),
                  G_CNN.FieldType(G_act, 2*[G_act.regular_repr]),
                  G_CNN.FieldType(G_act, 2*[G_act.regular_repr]),
                  G_CNN.FieldType(G_act, 2*[G_act.regular_repr]),
                  G_CNN.FieldType(G_act, [G_act.regular_repr]),
                  feat_type_out]
    # One kernel size per layer (len(feat_types) - 1 layers):
    kernel_sizes = [5, 7, 5, 7, 5]
    geom_decoder = My_Models.Steerable_Decoder(feat_types, kernel_sizes)

    # Assemble the two CNPs:
    conv_cnp = My_Models.Steerable_CNP(feature_in=feat_type_in, dim_cov_est=3, G_act=G_act, encoder=encoder_conv,
                                       decoder=conv_decoder, kernel_dict_out=dict(kernel_dict_out))
    geom_cnp = My_Models.Steerable_CNP(feature_in=feat_type_in, dim_cov_est=3, G_act=G_act, encoder=encoder_geom,
                                       decoder=geom_decoder, kernel_dict_out=dict(kernel_dict_out))
    # Send the models to the selected device:
    conv_cnp = conv_cnp.to(device)
    geom_cnp = geom_cnp.to(device)
    # Wrap each model in an operator holding the training settings and data loaders:
    Conv_CNP_Operator = My_Models.Steerable_CNP_Operator(conv_cnp, **Training_par, **Operator_par)
    Geom_CNP_Operator = My_Models.Steerable_CNP_Operator(geom_cnp, **Training_par, **Operator_par)
    return (Conv_CNP_Operator, Geom_CNP_Operator, GP_parameters)
# Run configuration shared by all sub-experiments of Exp. 4:
n_epochs=30
n_iterat=1000
train=False
evaluate=True
n_tests=400

# Directory of the saved checkpoints that are loaded when evaluate=True.
_CKPT_DIR = "Trained_Models/Initial_ziz_exp_1507/Exp_4/"

# One row per sub-experiment:
# (experiment id, learning rate, group order N, batch size,
#  ConvCNP checkpoint file, SteerCNP checkpoint file)
_EXPERIMENTS = [
    ("4.1", 1e-4, 8, 3, "Exp_4_1_Conv_CNP__2020_07_16_22_58", "Exp_4_1_Steerable_CNP__2020_07_16_22_29"),
    ("4.2", 1e-3, 4, 4, "Exp_4_2_Conv_CNP__2020_07_17_00_27", "Exp_4_2_Steerable_CNP__2020_07_16_23_45"),
    ("4.3", 1e-4, 4, 1, "Exp_4_3_Conv_CNP__2020_07_17_00_55", "Exp_4_3_Steerable_CNP__2020_07_17_00_42"),
]


def _run_experiment(exp_id, learning_rate, N, batch_size, conv_ckpt, steer_ckpt):
    """Set up one Exp. 4 sub-experiment, then optionally train and/or evaluate it.

    Reads the module-level flags ``train`` / ``evaluate`` and the settings
    ``n_epochs``, ``n_iterat``, ``n_tests`` and ``device``.

    Args:
        exp_id: experiment label such as "4.1"; used in file names and titles.
        learning_rate: learning rate passed in ``Training_par``.
        N: group order forwarded to ``SETUP_EXP_4_Cyclic_GP_div_free``.
        batch_size: data-loader batch size forwarded to the setup function.
        conv_ckpt: ConvCNP checkpoint file name inside ``_CKPT_DIR``.
        steer_ckpt: SteerCNP checkpoint file name inside ``_CKPT_DIR``.

    Returns:
        ``(Conv_CNP, Geom_CNP, GP_parameters)`` as built by the setup function.
    """
    Training_par = {'Max_n_context_points': 50, 'n_epochs': n_epochs, 'n_plots': None,
                    'n_iterat_per_epoch': n_iterat, 'learning_rate': learning_rate}
    Conv_CNP, Geom_CNP, GP_parameters = SETUP_EXP_4_Cyclic_GP_div_free(Training_par, N=N, batch_size=batch_size)
    if train:
        # "4.3" -> "Initial_ziz_exp_1507/Exp_4_3" (one naming scheme for every
        # experiment; 4.3 used to go through a mislabelled ``filename_13``).
        file_prefix = "Initial_ziz_exp_1507/Exp_" + exp_id.replace(".", "_")
        starttime = datetime.datetime.today()
        print("Start training experiment " + exp_id + ": ", starttime)
        Geom_CNP.train(filename=file_prefix + "_Steerable_CNP_")
        Conv_CNP.train(filename=file_prefix + "_Conv_CNP_")
        endtime = datetime.datetime.today()
        print("Duration of training on device: ", device, ": ", endtime - starttime)
    if evaluate:
        # NOTE(review): this always loads the fixed checkpoints listed in
        # _EXPERIMENTS, not the models that may just have been trained above.
        # map_location='cpu' lets the checkpoints load on CPU-only machines.
        Conv_CNP.load_state_dict(torch.load(_CKPT_DIR + conv_ckpt, map_location=torch.device('cpu')))
        Geom_CNP.load_state_dict(torch.load(_CKPT_DIR + steer_ckpt, map_location=torch.device('cpu')))
        # Plot both models on a context/target split (via
        # My_Tools.Rand_Target_Context_Splitter) of the first sample of a test batch.
        X, Y = next(iter(Conv_CNP.test_data_loader))
        n_context_points = torch.randint(size=[], low=2, high=Conv_CNP.Max_n_context_points)
        x_context, y_context, x_target, y_target = My_Tools.Rand_Target_Context_Splitter(X[0], Y[0], n_context_points)
        Conv_CNP.plot_test(x_context, y_context, x_target, y_target, GP_parameters=None,
                           title="Exp " + exp_id + ": ConvCNP")
        Geom_CNP.plot_test(x_context, y_context, x_target, y_target, GP_parameters=GP_parameters,
                           title="Exp " + exp_id + ": SteerCNP")
        print("Exp. " + exp_id + ": Log-LL Steer.: ", Geom_CNP.test(n_tests))
        print("Exp. " + exp_id + ": Log-LL Conv.: ", Conv_CNP.test(n_tests))
        Conv_CNP.test_equivariance_model(plot=True, title="ConvCNP")
        Geom_CNP.test_equivariance_model(plot=True, title="SteerCNP")
    return Conv_CNP, Geom_CNP, GP_parameters


# Run the experiments in order (4.1, 4.2, 4.3). The module-level names keep
# referring to the last experiment's operators, as before.
for _exp in _EXPERIMENTS:
    Conv_CNP, Geom_CNP, GP_parameters = _run_experiment(*_exp)